   
import torch
from torch import nn
import torch.nn.functional as F
    
# Posterior network that captures the information flow of the supporter's strategy.
class PosteriorNet(nn.Module):
    """Posterior network over the supporter's strategy latent variable.

    Only the layer definitions and two small helpers (``reparameterize`` and
    ``rec_loss_func``) are included here; ``forward`` is a placeholder whose
    implementation has been withheld by the authors.

    Args:
        utter_dim: Size of the utterance embedding. It is also used as
            ``strategy_dim``, i.e. the strategy embedding is assumed to have
            the same size as the utterance embedding.
        latent_dim: Size of the latent space. Hidden layers have
            ``latent_dim // 4`` units (presumably ``latent_dim >= 4`` so
            they are non-empty).
        posterior_type: ``'FC'`` builds ``strategy_feature_fusion`` and
            ``posterior_eps_strategy``; any other value skips them.
        mu_type: ``'share'`` builds ``posterior_mu_FC``.
        var_type: ``'share'`` builds ``posterior_logvar_FC``.
        use_reparameterize: Stored as ``self.use_reparameterize``; not read by
            the code in this file.
        activation: Module stored as ``self.activation``. Defaults to a fresh
            ``nn.ReLU(inplace=True)`` per instance. It is not used by the
            layers built here, which use their own ``nn.ReLU()``.
        dropout_prob: Stored as ``self.dropout``; not read by the code here.
    """

    def __init__(self,
                 utter_dim,
                 latent_dim,
                 posterior_type='FC',
                 mu_type='share',
                 var_type='share',
                 use_reparameterize=True,
                 activation=None,
                 dropout_prob=0.1,):
        super().__init__()

        self.posterior_type = posterior_type
        self.mu_type = mu_type
        self.var_type = var_type
        # The strategy embedding is assumed to share the utterance embedding size.
        self.strategy_dim = utter_dim
        self.utter_dim = utter_dim

        self.latent_dim = latent_dim
        self.dropout = dropout_prob

        # Build the default activation per instance. A module used as a default
        # argument would be created once and shared by every PosteriorNet.
        self.activation = nn.ReLU(inplace=True) if activation is None else activation
        self.use_reparameterize = use_reparameterize

        # Encoder: utter_dim -> latent_dim // 4 -> latent_dim.
        self.posterior_supporter_encoder = nn.Sequential(
            nn.Linear(self.utter_dim, self.latent_dim // 4),
            nn.ReLU(),
            nn.Linear(self.latent_dim // 4, self.latent_dim),
        )

        if posterior_type == 'FC':
            # Fuses a latent_dim feature with a strategy_dim feature into latent_dim.
            self.strategy_feature_fusion = nn.Linear(self.latent_dim + self.strategy_dim, self.latent_dim)
            # Maps 2 * latent_dim -> latent_dim // 4 -> latent_dim.
            self.posterior_eps_strategy = nn.Sequential(
                nn.Linear(self.latent_dim * 2, self.latent_dim // 4),
                nn.ReLU(),
                nn.Linear(self.latent_dim // 4, self.latent_dim),
            )

        if self.mu_type == 'share':
            self.posterior_mu_FC = nn.Linear(self.latent_dim, self.latent_dim)

        if self.var_type == 'share':
            self.posterior_logvar_FC = nn.Linear(self.latent_dim, self.latent_dim)

        # Generate flow via z_t and context:
        # latent_dim -> latent_dim // 4 -> latent_dim + utter_dim.
        self.posterior_supporter_decoder = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim // 4),
            nn.ReLU(),
            nn.Linear(self.latent_dim // 4, self.latent_dim + self.utter_dim),
        )

    def reparameterize(self, mu, logvar, test):
        """Sample from N(mu, exp(logvar)) with the reparameterization trick.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor, broadcastable with ``mu``.
            test: If true, skip sampling and return ``mu`` unchanged.

        Returns:
            ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)`` shaped like
            ``mu`` when ``test`` is false, otherwise ``mu`` itself.
        """
        if test:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        return eps * std + mu

    def rec_loss_func(self, input_emb, rec_emb):
        """Return the mean squared error between ``input_emb`` and ``rec_emb``.

        The error is averaged over all elements. ``reduction='mean'`` is the
        supported keyword; the legacy ``reduce`` flag is deprecated and emits
        a warning on every call.
        """
        return F.mse_loss(input_emb, rec_emb, reduction='mean')

    # The core implementation of the code has been removed, and the full code
    # will be released upon the paper's acceptance.
    def forward(self, strategy_cur, supporter_cur, z_po_last, mask, test=False):
        """Compute the posterior for the current turn (not available here).

        The implementation has been withheld. A full implementation presumably
        returns the 4-tuple ``(full_eps_mu, full_eps_logvar, full_z_cur,
        rec_loss)``; TODO confirm against the released code.

        Raises:
            NotImplementedError: Always, because the body is not included.
        """
        raise NotImplementedError(
            "PosteriorNet.forward is not included in this release; the full "
            "implementation will be published upon the paper's acceptance."
        )